import json
import random
from typing import Any, Callable, Dict, List, Optional

from corpus import Corpus, NltkCorpus
from logger_config import logger
from processor import WordCollectionProcessor
from related_words import DataMuseRelatedWords
from validators import EmbeddingsValidator, RelatedWordsPair, SimilarityTuple


class IncludesEvalTemplate:
    """Collect chat-format samples for an "includes"-style eval and export them as JSONL."""

    def __init__(self) -> None:
        # Per-instance storage. A class-level ``samples = []`` is a single list shared by
        # every instance, so samples from one factory would leak into another's export.
        self.samples: List[Dict[str, Any]] = []

    def create_sample(
        self,
        system_message: str,
        user_message: str,
        ideal_answer: str,
    ) -> Dict[str, Any]:
        """Build one sample, append it to ``self.samples``, and return it.

        Args:
            system_message: Content of the system chat turn.
            user_message: Content of the user chat turn.
            ideal_answer: Expected answer; stored wrapped in square brackets
                (e.g. ``"apple"`` -> ``"[apple]"``).

        Returns:
            A dict with ``"input"`` (the system and user messages as role/content dicts)
            and ``"ideal"`` (the bracketed answer).
        """
        sample = {
            "input": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            # Brackets match the "[<secret-word>]" answer format requested in the system prompt.
            "ideal": f"[{ideal_answer}]",
        }
        self.samples.append(sample)
        logger.debug(f"Sample created: {sample}")
        logger.info(f"{user_message} -> {ideal_answer}")
        return sample

    def export_to_jsonl(self, filename: str = "samples.jsonl") -> None:
        """Write every collected sample to ``filename`` (overwriting it), one JSON object per line."""
        with open(filename, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(sample) + "\n" for sample in self.samples)


def generate_additional_choices(
    word_association_pair: RelatedWordsPair,
    corpus: Corpus,
    num_choices: int = 5,
    shuffle_choices: bool = False,
) -> List[str]:
    """Return ``num_choices`` distractor words plus the correct answer.

    A corpus word qualifies as a distractor when it is neither the secret word nor one of
    its related words, and its embedding is less cosine-similar to the related-words
    string than the secret word's embedding is.

    Args:
        word_association_pair: Secret word (``.word``) and its related words
            (``.related_words``, a ", "-separated string).
        corpus: Pool of candidate distractor words.
        num_choices: Number of distractors to return.
        shuffle_choices: Shuffle the final list; when False the correct answer is last.

    Returns:
        ``num_choices`` distinct distractors and the correct answer
        (``num_choices + 1`` words in total).

    Raises:
        ValueError: If the corpus runs out of qualifying words before
            ``num_choices`` distractors are found.
    """
    correct_answer = word_association_pair.word
    # related_words is one ", "-joined string: exclude whole words, not substrings of it.
    excluded = {token.strip() for token in word_association_pair.related_words.split(",")}
    excluded.add(correct_answer)
    # dict.fromkeys drops duplicate corpus entries while keeping order; shuffling the pool
    # and walking it lets us draw candidates at random *without* replacement.
    candidates = [word for word in dict.fromkeys(corpus) if word not in excluded]
    random.shuffle(candidates)

    validator = EmbeddingsValidator(0.75)
    correct_answer_embedding = validator.get_embeddings(correct_answer)[0]
    related_words_embeddings = validator.get_embeddings(word_association_pair.related_words)[0]
    correct_answer_score = validator.calculate_cosine_similarity(
        correct_answer_embedding.vector, related_words_embeddings.vector
    )

    choices: List[str] = []
    # Each candidate is considered at most once, so there are no duplicate options and the
    # loop always terminates, even when too few words qualify.
    for candidate in candidates:
        if len(choices) >= num_choices:
            break
        candidate_embedding = validator.get_embeddings(candidate)[0]
        similarity = validator.calculate_cosine_similarity(
            candidate_embedding.vector, related_words_embeddings.vector
        )
        # Only accept words that fit the clues worse than the real answer does.
        if similarity < correct_answer_score:
            choices.append(candidate)

    if len(choices) < num_choices:
        raise ValueError("Not enough valid words in corpus to generate choices.")

    choices.append(correct_answer)
    if shuffle_choices:
        random.shuffle(choices)
    return choices


def generate_word_association_system_message(
    word_association_pair: RelatedWordsPair,
    parts_of_speech_choices: Optional[List[str]] = None,
) -> str:
    """Compose the system prompt for the word-association guessing game.

    The prompt explains the game, states how many related words will be shown and how
    many characters the secret word has, optionally lists the allowed parts of speech,
    and asks for step-by-step reasoning ending in a fixed answer format.

    Args:
        word_association_pair: Secret word (``.word``) and its ", "-separated
            related words (``.related_words``).
        parts_of_speech_choices: Optional part-of-speech options; when non-empty, an
            extra sentence mentioning them is included.

    Returns:
        The prompt sentences joined by single spaces.
    """
    clue_count = len(word_association_pair.related_words.split(", "))
    secret_length = len(word_association_pair.word)

    # Optional sentence, spliced into the list below only when choices were given.
    pos_sentence = (
        [f"The secret word is one of the following parts of speech: {parts_of_speech_choices}."]
        if parts_of_speech_choices
        else []
    )

    sentences = [
        "We are going to play a game of word association. I want you to guess the secret word.",
        f"I will give you {clue_count} words, the secret word is related to all {clue_count} of these words.",
        f"The secret word is {secret_length} characters long.",
        *pos_sentence,
        "What is the secret word? Before answering, reason in a step-by-step manner "
        "as to get the right answer, then conclude with the answer in the following format: "
        "The secret word is: [<secret-word.lower()>] because <reasoning>",
    ]

    system_message = " ".join(sentences)
    logger.debug(f"System message: {system_message}")
    return system_message


def generate_word_association_user_message(
    word_association_pair: RelatedWordsPair, corpus: Corpus
) -> str:
    """Build the user prompt showing the related words and the multiple-choice options.

    Args:
        word_association_pair: Secret word (``.word``) and its ", "-separated
            related words (``.related_words``).
        corpus: Pool from which distractor options are drawn.

    Returns:
        The user message listing the related words and the shuffled options.
    """
    # Shuffle the options. With the default shuffle_choices=False the correct answer is
    # always appended last, so its position alone would give it away.
    choices = generate_additional_choices(word_association_pair, corpus, shuffle_choices=True)
    # I have chosen to join the list into a string instead of just using the list because it uses fewer tokens
    user_message = (
        f"Here is a list of the related words: [{word_association_pair.related_words}]. Here is a list of "
        f"your options: [{', '.join(choices)}]. What is the secret word?"
    )
    logger.debug(f"User Message: {user_message}")
    return user_message


def taboo_clue_guesser_system_message() -> None:
    """Build the system message for the taboo clue-guesser eval (not implemented yet).

    Planned to resemble the word-association game. The difference is that the guesser
    must name the secret word from an LLM-written paragraph whose author was forbidden
    from using the related words, rather than being shown those words directly.

    Raises:
        NotImplementedError: Always, until this eval is implemented.
    """
    raise NotImplementedError()


def taboo_clue_giver_system_message() -> None:
    """Build the system message for the taboo clue-giver eval (not implemented yet).

    Planned task: the LLM writes a paragraph that helps a guesser find the secret word.
    Following the rules of the game Taboo, it may not use any word from the related-words
    list. This eval is intended to be a ModelGradedEval.

    Raises:
        NotImplementedError: Always, until this eval is implemented.
    """
    raise NotImplementedError()


def main(
    corpus: Corpus,
    related_words_length: int,
    max_samples: int = -1,
    export_file: Optional[str] = None,
    *filters: Callable[[Any], Any],
) -> None:
    """Generate word-association eval samples from ``corpus`` and export them to JSONL.

    For each unique corpus word, the pipeline:
    - fetches related words from DataMuse;
    - drops multi-word phrases and sub-word matches, then applies ``filters``;
    - keeps the word if at least ``related_words_length`` related words remain.

    Pairs accepted by the embeddings validator are turned into samples.

    Args:
        corpus: Candidate secret words; also the pool distractor options are drawn from.
        related_words_length: Related words shown per sample; words with fewer are skipped.
        max_samples: Cap on the number of samples created. Any negative value
            (the default, -1) means no cap.
        export_file: Output path; defaults to ``related_words_{related_words_length}.jsonl``.
        *filters: Extra filters applied in order to each word's list of related words,
            after the built-in filters. Each receives the current list and must return
            the filtered list.
    """
    eval_factory = IncludesEvalTemplate()

    word_association_pairs: List[RelatedWordsPair] = []
    # Dedupe and sort so words are processed (and samples emitted) in a stable order.
    corpus = sorted(set(corpus))
    for word in corpus:
        # The processor performs the filtration on the related words
        # (currently the only implemented processor works on both corpus and related words).
        related_processor = WordCollectionProcessor(DataMuseRelatedWords(word))
        # Remove 'words' that are actually phrases.
        related_processor.str_max_word_count_filter(1)
        # NOTE(review): given the `word` argument, this presumably removes related words
        # that overlap `word` as a sub-word — confirm in processor.py.
        related_processor.sub_word_filter(word)

        related_words = related_processor.words.words
        # Keep each filter's return value so the filters compose left to right.
        for filter_func in filters:
            related_words = filter_func(related_words)

        if len(related_words) < related_words_length:
            logger.info(
                f"Word: {word}, Related Words: {related_words}, Skipped - Not Enough Related Words"
            )
            continue

        related_words = related_words[:related_words_length]
        logger.info(f"Word: {word}, Related Words: {related_words}")
        word_association_pairs.append(RelatedWordsPair(word, ", ".join(related_words)))

    validator = EmbeddingsValidator(0.75)
    similarities: List[SimilarityTuple] = validator.validate(word_association_pairs)
    valid_samples: List[RelatedWordsPair] = [
        pair for pair, is_valid, _similarity_score in similarities if is_valid
    ]
    logger.info(f"Total Sample: {len(word_association_pairs)} Valid Samples: {len(valid_samples)}")

    # Apply the cap before building prompts. This way max_samples=0 really yields no
    # samples, and no distractor lookups are done for pairs that would be dropped anyway.
    if max_samples >= 0:
        valid_samples = valid_samples[:max_samples]

    for word_association_pair in valid_samples:
        system_message = generate_word_association_system_message(word_association_pair)
        user_message = generate_word_association_user_message(word_association_pair, corpus)
        eval_factory.create_sample(system_message, user_message, word_association_pair.word)

    if export_file is None:
        export_file = f"related_words_{related_words_length}.jsonl"
    eval_factory.export_to_jsonl(filename=export_file)


if __name__ == "__main__":
    # Baseline corpus: the NLTK "words" list, wrapped in a processor whose filter
    # calls narrow down what `processor.words` returns afterwards.
    processor = WordCollectionProcessor(NltkCorpus("words"))

    # Filter against the Brown corpus frequency distribution
    # (thresholds presumably bound the allowed frequency — confirm in processor.py).
    processor.frequency_filter(thresholds=(50, 10000), filter_corpus=NltkCorpus("brown"))

    # Restrict to 5-character words tagged NN or VB.
    processor.char_length_filter(length_bounds=(5, 5))
    processor.parts_of_speech_filter(["NN", "VB"])

    # Build and export the eval using 5 related words per secret word, with no sample cap.
    main(processor.words, related_words_length=5, max_samples=-1)
